import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from lime import lime_image
from skimage.segmentation import mark_boundaries
import pandas as pd
# 1. Load and Preprocess Data (Fashion MNIST)
# ---------------------------------------------------------
path = '/Users/hassothea/.cache/kagglehub/datasets/zalando-research/fashionmnist/versions/4'
data_test = pd.read_csv(path + '/fashion-mnist_test.csv')
data_train = pd.read_csv(path + '/fashion-mnist_train.csv')

# First CSV column is the integer label; the remaining 784 columns are pixels.
x_train, y_train = data_train.iloc[:, 1:], data_train['label']
x_test, y_test = data_test.iloc[:, 1:], data_test['label']

# Normalize to [0, 1] and reshape each flat 784-pixel row into a 28x28 image.
# NOTE: the rows must be reshaped to (28, 28) BEFORE adding the channel axis;
# calling np.expand_dims on the flat data would produce (N, 784, 1), which the
# Conv2D layer with input_shape=(28, 28, 1) cannot accept.
x_train = x_train.to_numpy().astype("float32").reshape(-1, 28, 28) / 255.0
x_test = x_test.to_numpy().astype("float32").reshape(-1, 28, 28) / 255.0
x_train = np.expand_dims(x_train, -1)  # (N, 28, 28) -> (N, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)

# Class names for Fashion MNIST (list index == integer label)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 2. Build or Load Your Model
# ---------------------------------------------------------
# Train a small CNN once and cache it on disk; later runs just load it.
import os

# Where the trained model is cached between runs.
model_path = './models/fashion_mnist_lime_model.keras'

if os.path.exists(model_path):
    print(f"Found saved model at '{model_path}'. Loading...")
    model = keras.models.load_model(model_path)
else:
    print("No saved model found. Training now...")
    # Small CNN: one conv block plus a dense classifier head.
    model = keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)
    # Ensure the target directory exists before saving, otherwise save() fails.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    model.save(model_path)
# NOTE(review): the original script re-compiled and re-trained the model for
# one extra epoch here on EVERY run -- including runs that had just loaded the
# cached model -- which defeated the purpose of saving it. That redundant
# compile/fit step has been removed.
# 3. Define the Prediction Wrapper for LIME
# ---------------------------------------------------------
def to_grayscale_predict(images_rgb):
    """Adapter between LIME's RGB perturbations and the grayscale model.

    LIME hands us batches shaped (N, 28, 28, 3); the Keras model was
    trained on single-channel input of shape (N, 28, 28, 1). Since LIME
    simply replicates the gray channel across all three RGB channels,
    keeping any one channel recovers the original grayscale image.
    """
    # Slice (not index) so the trailing channel axis is preserved.
    single_channel = images_rgb[..., :1]
    return model.predict(single_channel, verbose=0)
# 4. Setup LIME Explainer
# ---------------------------------------------------------
explainer = lime_image.LimeImageExplainer()
# Select a test image to explain
idx = 89 # Change this index to see different examples
img_to_explain_gray = x_test[idx,:].reshape(28,28,1) # Shape (28, 28, 1)
# IMPORTANT: LIME expects the input image to have 3 channels (RGB).
# We duplicate the grayscale channel 3 times to satisfy LIME.
# (to_grayscale_predict undoes this by taking the first channel back.)
img_to_explain_rgb = np.repeat(img_to_explain_gray, 3, axis=2)
# Generate Explanation
# top_labels=5: analyze the top 5 predicted classes
# hide_color=0: replace perturbed superpixels with black (0)
# num_samples=1000: number of perturbations to generate
# 1. Pixel-wise segmentation function (crucial for small 28x28 images):
# each pixel becomes its own superpixel, so LIME attributes importance at
# pixel granularity instead of over coarse segments.
def pixel_segmentation(image):
    """Return a segment map assigning every pixel its own unique label.

    Parameters
    ----------
    image : np.ndarray
        Image of shape (H, W) or (H, W, C); only the spatial dims are used.

    Returns
    -------
    np.ndarray
        Integer array of shape (H, W) with labels 0 .. H*W - 1.

    Notes
    -----
    The original version hard-coded 28x28 and ignored ``image``; this
    generalization reads the spatial dimensions from the input and is
    identical for 28x28 images.
    """
    height, width = image.shape[:2]
    return np.arange(height * width).reshape(height, width)
# 2. Generate the Explanation
# We ask LIME to find the top 5 labels.
# LIME fits a local linear surrogate over num_samples perturbed copies of
# the image, masking superpixels with hide_color where they are "off".
explanation = explainer.explain_instance(
img_to_explain_rgb,
to_grayscale_predict,
top_labels=5,
hide_color=0, # Mask with black (0) for MNIST
num_samples=1000,
segmentation_fn=pixel_segmentation
)
# 3. Retrieve the Label from the Explanation Object
# ---------------------------------------------------------
# explanation.top_labels contains the class indices sorted by model probability.
# Index [0] is the class with the highest probability (the winner).
# Index [1] would be the runner-up, etc.
target_label = explanation.top_labels[0]
print(f"LIME identified class {target_label} ({class_names[target_label]}) as the top prediction.")
# 4. Get image and mask using that internal label
# positive_only=False keeps both supporting and opposing regions;
# hide_rest=False leaves unexplained pixels visible in the returned image.
temp, mask = explanation.get_image_and_mask(
target_label, # Use the label extracted from explanation
positive_only=False,
num_features=5,
hide_rest=False
)
# 5. Plot: original test image beside the LIME explanation overlay.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: the raw grayscale image with its ground-truth label.
axes[0].imshow(img_to_explain_gray.squeeze(), cmap='BuGn')
axes[0].set_title(f"True Label: {class_names[y_test[idx]]}")
axes[0].axis('off')

# Right panel: LIME's selected superpixel boundaries drawn on the image.
axes[1].imshow(mark_boundaries(temp, mask), cmap='BuGn')
axes[1].set_title(f"Explanation for '{class_names[target_label]}'")
axes[1].axis('off')

fig.tight_layout()
plt.show()